#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
题库管理功能数据库迁移脚本
添加question_banks表和修改questions表
"""

import sys
import os
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from datetime import datetime
import logging

# 添加当前目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from config import Config

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def migrate_database():
    """执行题库管理功能的数据库迁移"""
    try:
        # 创建数据库引擎
        engine = create_engine(Config.DATABASE_URL)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        
        logger.info("开始题库管理功能数据库迁移...")
        
        with engine.connect() as conn:
            # 开始事务
            trans = conn.begin()
            
            try:
                # 1. 创建question_banks表
                logger.info("创建question_banks表...")
                
                # 检查表是否已存在
                result = conn.execute(text("""
                    SELECT COUNT(*) as count FROM sqlite_master 
                    WHERE type='table' AND name='question_banks'
                """))
                table_exists = result.fetchone()[0] > 0
                
                if not table_exists:
                    # 创建question_banks表
                    conn.execute(text("""
                        CREATE TABLE question_banks (
                            id INTEGER PRIMARY KEY AUTOINCREMENT,
                            name VARCHAR NOT NULL,
                            description VARCHAR,
                            created_time DATETIME DEFAULT CURRENT_TIMESTAMP,
                            updated_time DATETIME DEFAULT CURRENT_TIMESTAMP,
                            created_by VARCHAR,
                            is_active INTEGER DEFAULT 1,
                            question_count INTEGER DEFAULT 0
                        )
                    """))
                    
                    # 创建索引
                    conn.execute(text("CREATE INDEX idx_question_banks_name ON question_banks(name)"))
                    conn.execute(text("CREATE INDEX idx_question_banks_created_time ON question_banks(created_time)"))
                    
                    logger.info("question_banks表创建完成")
                else:
                    logger.info("question_banks表已存在，跳过创建")
                
                # 2. 为questions表添加question_bank_id字段
                logger.info("为questions表添加question_bank_id字段...")
                
                # 检查字段是否已存在
                result = conn.execute(text("""
                    SELECT COUNT(*) as count FROM pragma_table_info('questions') 
                    WHERE name = 'question_bank_id'
                """))
                field_exists = result.fetchone()[0] > 0
                
                if not field_exists:
                    # 添加question_bank_id字段
                    conn.execute(text("ALTER TABLE questions ADD COLUMN question_bank_id INTEGER"))
                    
                    # 创建索引
                    conn.execute(text("CREATE INDEX idx_questions_question_bank_id ON questions(question_bank_id)"))
                    
                    logger.info("questions表question_bank_id字段添加完成")
                else:
                    logger.info("questions表question_bank_id字段已存在，跳过添加")
                
                # 3. 创建默认题库并将现有题目分配给它
                logger.info("创建默认题库...")
                
                # 检查是否已有题库
                result = conn.execute(text("SELECT COUNT(*) FROM question_banks"))
                bank_count = result.fetchone()[0]
                
                if bank_count == 0:
                    # 创建默认题库
                    now = datetime.now()
                    conn.execute(text("""
                        INSERT INTO question_banks (name, description, created_by, created_time, updated_time)
                        VALUES (:name, :description, :created_by, :created_time, :updated_time)
                    """), {
                        'name': '默认题库',
                        'description': '系统自动创建的默认题库，包含所有现有题目',
                        'created_by': 'system',
                        'created_time': now,
                        'updated_time': now
                    })
                    
                    # 获取默认题库ID
                    result = conn.execute(text("SELECT id FROM question_banks WHERE name = '默认题库'"))
                    default_bank_id = result.fetchone()[0]
                    
                    # 将所有现有题目分配给默认题库
                    conn.execute(text("""
                        UPDATE questions SET question_bank_id = :bank_id WHERE question_bank_id IS NULL
                    """), {'bank_id': default_bank_id})
                    
                    # 更新默认题库的题目数量
                    result = conn.execute(text("SELECT COUNT(*) FROM questions WHERE question_bank_id = :bank_id"), {'bank_id': default_bank_id})
                    question_count = result.fetchone()[0]
                    
                    conn.execute(text("""
                        UPDATE question_banks SET question_count = :count WHERE id = :bank_id
                    """), {'count': question_count, 'bank_id': default_bank_id})
                    
                    logger.info(f"默认题库创建完成，包含 {question_count} 道题目")
                else:
                    logger.info("题库已存在，跳过默认题库创建")
                
                # 提交事务
                trans.commit()
                logger.info("题库管理功能数据库迁移完成！")
                
            except Exception as e:
                # 回滚事务
                trans.rollback()
                logger.error(f"数据库迁移失败: {str(e)}")
                raise
                
    except Exception as e:
        logger.error(f"数据库迁移过程中发生错误: {str(e)}")
        raise

def rollback_migration():
    """回滚迁移（仅用于开发测试）"""
    try:
        engine = create_engine(Config.DATABASE_URL)
        
        logger.info("开始回滚题库管理功能迁移...")
        
        with engine.connect() as conn:
            trans = conn.begin()
            
            try:
                # 删除question_bank_id字段（SQLite不支持直接删除列，需要重建表）
                logger.info("回滚questions表修改...")
                
                # 备份数据
                conn.execute(text("""
                    CREATE TABLE questions_backup AS 
                    SELECT id, question_content, option_a, option_b, option_c, option_d, 
                           answer, knowledge_point, explanation, wrong_analysis, 
                           created_time, updated_time, difficulty_level, topic_category
                    FROM questions
                """))
                
                # 删除原表
                conn.execute(text("DROP TABLE questions"))
                
                # 重建表（不包含question_bank_id）
                conn.execute(text("""
                    CREATE TABLE questions (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        question_content VARCHAR NOT NULL,
                        option_a VARCHAR NOT NULL,
                        option_b VARCHAR NOT NULL,
                        option_c VARCHAR NOT NULL,
                        option_d VARCHAR NOT NULL,
                        answer VARCHAR NOT NULL,
                        knowledge_point VARCHAR NOT NULL,
                        explanation VARCHAR,
                        wrong_analysis VARCHAR,
                        created_time DATETIME DEFAULT CURRENT_TIMESTAMP,
                        updated_time DATETIME DEFAULT CURRENT_TIMESTAMP,
                        difficulty_level VARCHAR,
                        topic_category VARCHAR
                    )
                """))
                
                # 恢复数据
                conn.execute(text("""
                    INSERT INTO questions SELECT * FROM questions_backup
                """))
                
                # 删除备份表
                conn.execute(text("DROP TABLE questions_backup"))
                
                # 删除question_banks表
                conn.execute(text("DROP TABLE IF EXISTS question_banks"))
                
                trans.commit()
                logger.info("迁移回滚完成！")
                
            except Exception as e:
                trans.rollback()
                logger.error(f"回滚失败: {str(e)}")
                raise
                
    except Exception as e:
        logger.error(f"回滚过程中发生错误: {str(e)}")
        raise

if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description='题库管理功能数据库迁移')
    parser.add_argument('--rollback', action='store_true', help='回滚迁移')
    
    args = parser.parse_args()
    
    if args.rollback:
        rollback_migration()
    else:
        migrate_database()